#include "Simulator.h"
#include <stdexcept>

//////////////////////////////////////////////////////////////////////////
//Author	:	Ross Conroy ross.conroy@tees.ac.uk
//Date		:	05/08/2014
//
//This class runs simulations on DID and I-DID models using random number
//generation (Monte Carlo simulation)
//
//Update 08/08/2014
//Made code more modular to support simulating I-DIDs
//////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////////////////
//Seeds the C library random number generator (used by rand()) with the
//current calendar time, so runs started at different seconds generally
//draw different random sequences.
//////////////////////////////////////////////////////////////////////////
Simulator::Simulator(void)
{
	const time_t now = time(NULL);
	srand(now);
}

//////////////////////////////////////////////////////////////////////////
//This method runs simulations on DID models using the following algorithm
//
//FOR i = 0 TO i = number of simulations DO
//	choose a random model
//	FOR t = 1 TO t = number of time steps Do
//		record time step t
//		choose random state st using its probability distribution
//		record state of st
//		propagate evidence
//		choose random observation ot using its probability distribution
//		record observation ot
//		remove evidence st TO s1
//		choose action at
//		record action at
//		re-enter evidence s1 TO st
//		record reward at ut
//	END FOR
//END FOR
//////////////////////////////////////////////////////////////////////////
list<DidResult> Simulator::SimulateDID(list<Domain*> inputModels, string statePrefix, string observationPrefix, 
										string actionPrefix, string utilityPrefix, int numberOfSimulations)
{
	list<DidResult> collected;

	//every candidate model is compiled once, before any simulation runs
	for(list<Domain*>::iterator modelIt = inputModels.begin(); modelIt != inputModels.end(); ++modelIt)
	{
		(*modelIt)->compile();
	}

	for(int sim = 0; sim < numberOfSimulations; ++sim)
	{
		const double progress = (static_cast<double>(sim) / static_cast<double>(numberOfSimulations)) * 100.0;
		cout << "\rProgress:              " << flush;
		cout << "\rProgress: " << progress << "%" << flush;

		//pick one of the models at random for this whole simulation
		Domain * didDomain = ChooseRandomModel(inputModels);
		ExpandTimeSteps stepFinder;
		const int lastStep = stepFinder.FindLastTimeStep(didDomain);

		//<state node name, sampled state index> for every step so far
		hash_map<string, int> sampledStates;

		for(int step = 1; step <= lastStep; ++step)
		{
			DIDTimeStepNodes nodes = GetDIDNodesForTimeStep(didDomain, step, statePrefix, observationPrefix, actionPrefix, utilityPrefix);

			//sample the state for this step and fix it as hard evidence
			const int sIdx = ChooseRandomStateForNode(nodes.stateNode);
			SetNodeFindings(nodes.stateNode, sIdx);
			string stateLabel = nodes.stateNode->getStateLabel(sIdx);
			sampledStates.insert(make_pair(nodes.stateNode->getName(), sIdx));
			didDomain->propagate();

			string observationLabel;
			if(step == 1)
			{
				//the first step has no observation, and neither label is reported
				observationLabel = "-";
				stateLabel = "-";
			}
			else
			{
				//sample an observation given the state evidence entered above
				const int oIdx = ChooseRandomStateForNode(nodes.observationNode);
				SetNodeFindings(nodes.observationNode, oIdx);
				observationLabel = nodes.observationNode->getStateLabel(oIdx);
				didDomain->propagate();
			}

			//retract all sampled state evidence before the decision is chosen
			ClearNodes(sampledStates, didDomain);

			//take the action with the highest expected utility and fix it
			const int aIdx = ChooseDecisionForNode(nodes.decisionNode);
			SetNodeFindings(nodes.decisionNode, aIdx);
			const string actionLabel = nodes.decisionNode->getStateLabel(aIdx);
			didDomain->propagate();

			//put the sampled states back before reading the step's utility
			ResetNodes(sampledStates, didDomain);

			const double reward = nodes.utilityNode->getExpectedUtility();

			DidResult row;
			row.model = didDomain->getAttribute("model");
			row.timeStep = step;
			row.state = stateLabel;
			row.observation = observationLabel;
			row.action = actionLabel;
			row.utility = reward;

			collected.push_back(row);
		}
	}

	cout << "\r                       " << flush;
	cout << "\r" << flush;

	return collected;
}

//////////////////////////////////////////////////////////////////////////
//Simulates an I-DID agent playing against a DID agent using the following
//steps
//
//FOR Each DID Model IN Models DO
//	Add model to I-DID
//END FOR
//
//FOR i = 0 TO i = number of simulations DO
//	choose a random DID model
//	FOR t = 1 TO t = number of time steps Do
//		record time step t
//		choose random state st from I-DID using its probability distribution
//		record state of st I-DID		
//
//		enter state st into DID
//		propagate evidence DID
//
//		choose random observation ot in DID using its probability distribution
//		remove evidence st TO s1 in DID
//		choose at from DID
//		re-enter evidence st into DID (only the current state is re-entered)
//		propagate evidence DID
//		
//		enter at from DID into I-DID aj
//
//		propagate evidence I-DID
//		choose random observation ot from I-DID using its probability distribution
//		record observation ot
//		remove evidence st TO s1 in I-DID
//		remove evidence ajt to aj1 in I-DID
//		choose action at in I-DID
//		record action at in I-DID
//		re-enter evidence st in I-DID (only the current state is re-entered)
//		re-enter evidence aj1 to ajt in I-DID
//		propagate evidence in I-DID
//		record reward at ut in I-DID
//	END FOR
//END FOR
//////////////////////////////////////////////////////////////////////////
list<IDidResult> Simulator::SimulateIDID(Domain * inputIdid, list<Domain*> inputModels, string statePrefix, string utilityPrefix,
											string iObservationPrefix, string jObservationPrefix, string iActionPrefix, 
											string jActionPrefix, int numberOfSimulations, bool modelWeighting, string modelPrefix)
{
	list<IDidResult> results;

	//the I-DID's last time step bounds every simulation, whichever DID is chosen
	ExpandTimeSteps expander;
	int numSteps = expander.FindLastTimeStep(inputIdid);

	//Compile all opponent (DID) domains
	for(list<Domain*>::iterator it = inputModels.begin(); it != inputModels.end(); it++)
	{
		Domain * domain = *it;
		domain->compile();
	}

	inputIdid->compile();
	inputIdid->retractFindings();
	inputIdid->propagate();

	//Run simulations
	for(int i = 0; i < numberOfSimulations; i++)
	{
		double percentageComplete = ((double)i / (double)numberOfSimulations) * (double)100;
		cout << "\rProgress:              " << flush;
		cout << "\rProgress: " << percentageComplete << "%" << flush;

		//reset the chosen-model index; ChooseRandomModel sets it below
		currentModel = -1;

		//The I-DID domain itself is reused (not cloned) for every simulation,
		//so all of its evidence is retracted before each run.
		Domain * iDIDModel = inputIdid;
		Domain * didModel = ChooseRandomModel(inputModels);

		iDIDModel->retractFindings();
		iDIDModel->propagate();

		hash_map<string, int> pastStates;          //<I-DID state node name, state index>
		hash_map<string, int> pastOponentActions;  //<I-DID opponent action node name, action index>

		try
		{
			//simulate each time step
			for(int t = 1; t <= numSteps; t++)
			{
				//get nodes for time step t from both models
				IDIDTimeStepNodes timeStepNodesIDID = GetIDIDNodesForTimeStep(iDIDModel, t, statePrefix, iObservationPrefix, iActionPrefix, utilityPrefix, jActionPrefix, modelPrefix);
				DIDTimeStepNodes timeStepNodesDID = GetDIDNodesForTimeStep(didModel, t, statePrefix, jObservationPrefix, jActionPrefix, utilityPrefix);

				//Sample the state from the I-DID and enter it into both models
				int stateIndex = ChooseRandomStateForNode(timeStepNodesIDID.stateNode);
				SetNodeFindings(timeStepNodesIDID.stateNode, stateIndex);
				SetNodeFindings(timeStepNodesDID.stateNode, stateIndex);

				string stateName = timeStepNodesIDID.stateNode->getStateLabel(stateIndex);
				pastStates.insert(make_pair(timeStepNodesIDID.stateNode->getName(), stateIndex));

				iDIDModel->propagate();
				didModel->propagate();

				string observationName;

				if(t != 1)
				{
					//each model samples its own observation independently
					int observationIndexIDID = ChooseRandomStateForNode(timeStepNodesIDID.observationNode);
					observationName = timeStepNodesIDID.observationNode->getStateLabel(observationIndexIDID);
					int observationIndexDID = ChooseRandomStateForNode(timeStepNodesDID.observationNode);

					SetNodeFindings(timeStepNodesIDID.observationNode, observationIndexIDID);
					SetNodeFindings(timeStepNodesDID.observationNode, observationIndexDID);

					iDIDModel->propagate();
					didModel->propagate();
				}
				else
				{
					//no observation at t = 1
					observationName = "-";

					//optionally select the I-DID model-node state labelled with the
					//index of the DID chosen for this simulation
					if(modelWeighting)
					{
						ApplyModelWeighting(timeStepNodesIDID.modelNode);
						iDIDModel->propagate();
					}
				}

				//Retract state evidence from both models, and the opponent's past
				//actions from the I-DID, before the decisions are chosen
				ClearNodes(pastStates, iDIDModel);
				ClearNodes(pastStates, didModel);
				ClearNodes(pastOponentActions, iDIDModel);
				iDIDModel->propagate();
				didModel->propagate();

				//Choose the action with the highest expected utility in each model
				int actionIndexIDID = ChooseDecisionForNode(timeStepNodesIDID.decisionNode);
				timeStepNodesIDID.decisionNode->selectState(actionIndexIDID);
				string actionName = timeStepNodesIDID.decisionNode->getStateLabel(actionIndexIDID);

				int actionIndexDID = ChooseDecisionForNode(timeStepNodesDID.decisionNode);
				string opponentActionName = timeStepNodesDID.decisionNode->getStateLabel(actionIndexDID);
				timeStepNodesDID.decisionNode->selectState(actionIndexDID);

				//Enter the opponent's action into the I-DID's opponent action node.
				//NOTE(review): the DID's state index is reused directly, which assumes
				//both nodes list their states in the same order - confirm.
				SetNodeFindings(timeStepNodesIDID.oponentActionNode, actionIndexDID);
				pastOponentActions.insert(make_pair(timeStepNodesIDID.oponentActionNode->getName(), actionIndexDID));

				//Re-enter the current state st in both models. Unlike SimulateDID,
				//earlier states s1..st-1 are not re-entered here.
				SetNodeFindings(timeStepNodesIDID.stateNode, stateIndex);
				SetNodeFindings(timeStepNodesDID.stateNode, stateIndex);

				didModel->propagate();
				iDIDModel->propagate();

				//Get the I-DID agent's expected utility for this step
				double utility = timeStepNodesIDID.utilityNode->getExpectedUtility();

				//restore every opponent action seen so far as I-DID evidence
				ResetNodes(pastOponentActions, iDIDModel);

				IDidResult iDidResult;
				iDidResult.oponentModel = didModel->getAttribute("model");
				iDidResult.timeStep = t;
				iDidResult.state = stateName;
				iDidResult.observation = observationName;
				iDidResult.action = actionName;
				iDidResult.oponentAction = opponentActionName;
				iDidResult.utility = utility;

				results.push_back(iDidResult);
			}
		}
		catch (const ExceptionHugin &)
		{
			//Caught by reference to avoid copying/slicing derived Hugin exceptions.
			//A Hugin failure aborts the rest of this simulation: rows already
			//recorded for earlier time steps are kept and one error row (all
			//text fields "ER", utility -100) is appended.
			IDidResult iDidResult;
			iDidResult.oponentModel = "ER";
			iDidResult.timeStep = numSteps;
			iDidResult.state = "ER";
			iDidResult.observation = "ER";
			iDidResult.action = "ER";
			iDidResult.oponentAction = "ER";
			iDidResult.utility = -100;

			results.push_back(iDidResult);
		}
	}

	cout << "\r                       " << flush;
	cout << "\r" << flush;

	return results;
}

//////////////////////////////////////////////////////////////////////////
//Enters hard evidence on a node: the finding for the selected state is
//set to 1 and the finding for every other state is set to 0. No
//propagation is performed here.
//////////////////////////////////////////////////////////////////////////
void Simulator::SetNodeFindings(DiscreteNode * node, int state)
{
	int idx = 0;
	while(idx < node->getNumberOfStates())
	{
		node->enterFinding(idx, (idx == state) ? 1 : 0);
		++idx;
	}
}

//////////////////////////////////////////////////////////////////////////
//Re-enters hard evidence for a set of nodes, then propagates.
//
//nodes maps node name -> index of the state to select. Each named node is
//looked up in inputID and given a finding of 1 for that state and 0 for
//every other state (see SetNodeFindings). A name that inputID does not
//contain is skipped instead of dereferencing a NULL node.
//
//A single propagate() is performed after all findings are entered, and
//only if at least one finding was entered. Propagation computes beliefs
//from the evidence currently entered, so this yields the same beliefs as
//propagating after every node, without one full propagation per node.
//////////////////////////////////////////////////////////////////////////
void Simulator::ResetNodes(hash_map<string, int> nodes, Domain * inputID)
{
	bool enteredAny = false;

	for(hash_map<string, int>::reverse_iterator its = nodes.rbegin(); its != nodes.rend(); ++its)
	{
		DiscreteNode * sNode = (DiscreteNode*)inputID->getNodeByName(its->first);
		if(sNode == NULL)
		{
			continue;
		}

		SetNodeFindings(sNode, its->second);
		enteredAny = true;
	}

	if(enteredAny)
	{
		inputID->propagate();
	}
}

//////////////////////////////////////////////////////////////////////////
//Retracts all evidence from a set of nodes, then propagates.
//
//Only the keys (node names) of nodes are used. A name that inputID does
//not contain is skipped instead of dereferencing a NULL node.
//
//The unused belief reads left over from debugging were removed. The
//getBelief(1) read was invalid for a single-state node. One propagate() is
//done after all retractions, and only if something was retracted; this
//gives the same beliefs as propagating after every node.
//////////////////////////////////////////////////////////////////////////
void Simulator::ClearNodes(hash_map<string, int> nodes, Domain * inputID)
{
	bool retractedAny = false;

	for(hash_map<string, int>::const_iterator its = nodes.begin(); its != nodes.end(); ++its)
	{
		DiscreteNode * sNode = (DiscreteNode*)inputID->getNodeByName(its->first);
		if(sNode == NULL)
		{
			continue;
		}

		sNode->retractFindings();
		retractedAny = true;
	}

	if(retractedAny)
	{
		inputID->propagate();
	}
}

//////////////////////////////////////////////////////////////////////////
//Chooses a random model from a list of models.
//
//Side effects:
//- stores the chosen index in the member currentModel (read later by
//  ApplyModelWeighting);
//- tags the chosen domain with a "model" attribute holding that index;
//- retracts all findings on the chosen domain and propagates it.
//
//Returns the chosen domain itself (not a copy). Throws
//std::invalid_argument if models is empty, where rand() % 0 would be a
//division by zero.
//////////////////////////////////////////////////////////////////////////
Domain * Simulator::ChooseRandomModel(list<Domain*> models)
{
	if(models.empty())
	{
		throw std::invalid_argument("Simulator::ChooseRandomModel: no models supplied");
	}

	int rModel = static_cast<int>(rand() % models.size());
	currentModel = rModel;

	list<Domain*>::iterator itm = models.begin();
	advance(itm, rModel);

	//named 'chosen' so it no longer shadows the member currentModel
	Domain * chosen = *itm;

	stringstream ss;
	ss << rModel;
	chosen->setAttribute("model", ss.str());

	chosen->retractFindings();
	chosen->propagate();

	return chosen;
}

//////////////////////////////////////////////////////////////////////////
//Enters hard evidence on the I-DID model node. The selected state is the
//one whose label is the decimal text of the member currentModel (the index
//stored by ChooseRandomModel).
//
//Does nothing if modelNode is NULL (e.g. no node matched the model
//prefix). It also does nothing if getStateIndex does not return a
//non-negative index for that label.
//////////////////////////////////////////////////////////////////////////
void Simulator::ApplyModelWeighting(DiscreteNode * modelNode)
{
	if(modelNode == NULL)
	{
		return;
	}

	stringstream ss;
	ss << currentModel;
	string modelState = ss.str();

	int index = modelNode->getStateIndex(modelState);

	if(index > -1)
	{
		SetNodeFindings(modelNode, index);
	}
}

//////////////////////////////////////////////////////////////////////////
//Collects the DID nodes that belong to one time step. Each field holds
//the first node of that step whose name starts with the matching prefix
//(see FindNodeWithPrefix). The result is cast to the expected node type
//without any type check.
//////////////////////////////////////////////////////////////////////////
DIDTimeStepNodes Simulator::GetDIDNodesForTimeStep(Domain * inputID, int timeStep, string statePrefix, string observationPrefix, 
												string decisionPrefix, string utilityPrefix)
{
	ExpandTimeSteps expander;
	NodeList stepNodes = expander.GetNodesForTimeStep(inputID, timeStep);

	DIDTimeStepNodes found;
	found.stateNode = (DiscreteNode*)FindNodeWithPrefix(stepNodes, statePrefix);
	found.observationNode = (DiscreteNode*)FindNodeWithPrefix(stepNodes, observationPrefix);
	found.decisionNode = (DiscreteDecisionNode*)FindNodeWithPrefix(stepNodes, decisionPrefix);
	found.utilityNode = (UtilityNode*)FindNodeWithPrefix(stepNodes, utilityPrefix);

	return found;
}

//////////////////////////////////////////////////////////////////////////
//Collects the I-DID nodes that belong to one time step. This is the same
//lookup as for a DID, plus the opponent (j) action node and the model
//node. Each field holds the first node of that step whose name starts
//with the matching prefix; results are cast without any type check.
//////////////////////////////////////////////////////////////////////////
IDIDTimeStepNodes Simulator::GetIDIDNodesForTimeStep(Domain * inputID, int timeStep, string statePrefix, string observationPrefix, string decisionPrefix, string utilityPrefix, string jActionPrefix, string modelPrefix)
{
	ExpandTimeSteps expander;
	NodeList stepNodes = expander.GetNodesForTimeStep(inputID, timeStep);

	IDIDTimeStepNodes found;
	found.stateNode = (DiscreteNode*)FindNodeWithPrefix(stepNodes, statePrefix);
	found.observationNode = (DiscreteNode*)FindNodeWithPrefix(stepNodes, observationPrefix);
	found.decisionNode = (DiscreteDecisionNode*)FindNodeWithPrefix(stepNodes, decisionPrefix);
	found.utilityNode = (UtilityNode*)FindNodeWithPrefix(stepNodes, utilityPrefix);
	found.oponentActionNode = (DiscreteNode*)FindNodeWithPrefix(stepNodes, jActionPrefix);
	found.modelNode = (DiscreteNode*)FindNodeWithPrefix(stepNodes, modelPrefix);

	return found;
}

//////////////////////////////////////////////////////////////////////////
//Chooses a random state for a node from its current belief distribution
//(inverse-CDF sampling). The node's domain is assumed to have been
//propagated by the caller, so that getBelief() reflects current evidence.
//
//Returns the sampled state index. If rounding leaves the beliefs summing
//to less than the random draw, the last state with a positive belief is
//returned. 0 is returned if no state has a positive belief.
//////////////////////////////////////////////////////////////////////////
int Simulator::ChooseRandomStateForNode(DiscreteNode * node)
{
	//Uniform draw in [0, 1). RAND_MAX + 1.0 is evaluated in floating point:
	//the integer expression RAND_MAX + 1 overflows (undefined behaviour)
	//where RAND_MAX == INT_MAX, e.g. glibc.
	double r = (double)rand() / ((double)RAND_MAX + 1.0);

	int lastPossible = 0;

	for(int i = 0; i < node->getNumberOfStates(); i++)
	{
		double b = node->getBelief(i);

		//a certain state needs no random number
		if(b == 1)
		{
			return i;
		}

		if(b > 0)
		{
			lastPossible = i;
		}

		r -= b;

		//Strict comparison: state i is chosen when r lies in
		//[cumulative(i-1), cumulative(i)), so a zero-belief state is never
		//picked when the draw is exactly 0.
		if(r < 0)
		{
			return i;
		}
	}

	return lastPossible;
}

//////////////////////////////////////////////////////////////////////////
//Chooses a decision for a decision node: the index of the option with
//the highest expected utility. Ties go to the lowest index, and a node
//without states yields 0. The node's domain is assumed to have been
//propagated with the relevant evidence by the caller.
//////////////////////////////////////////////////////////////////////////
int Simulator::ChooseDecisionForNode(DiscreteDecisionNode * node)
{
	int best = 0;
	double bestUtility = 0.0;

	for(int option = 0; option < node->getNumberOfStates(); ++option)
	{
		const double eu = node->getExpectedUtility(option);
		if(option == 0 || eu > bestUtility)
		{
			best = option;
			bestUtility = eu;
		}
	}

	return best;
}

//////////////////////////////////////////////////////////////////////////
//Returns the first node in nodes whose name starts with prefix, or NULL
//if none does. An empty prefix matches the first node.
//
//Previously control fell off the end of this non-void function when
//nothing matched, which is undefined behaviour. This can happen, e.g., for
//a time step that has no observation node, such as t = 1.
//////////////////////////////////////////////////////////////////////////
Node * Simulator::FindNodeWithPrefix(NodeList nodes, string prefix)
{
	for (NodeList::const_iterator it = nodes.begin(); it != nodes.end(); ++it)
	{
		Node * node = *it;
		string nodeName = node->getName();

		//compare only the leading characters rather than searching the
		//whole name as find() did; the match condition is unchanged
		if(nodeName.compare(0, prefix.size(), prefix) == 0)
		{
			return node;
		}
	}

	return NULL;
}

//////////////////////////////////////////////////////////////////////////
//Writes DID simulation results to fileName as CSV, replacing any
//existing file. The output is one header row followed by one row per
//result, and every row ends with a trailing comma.
//////////////////////////////////////////////////////////////////////////
void Simulator::SaveDIDToCsvFile(list<DidResult>results, string fileName)
{
	ofstream csv(fileName.c_str());

	csv << "Model,TimeStep,State,Observation,Action,Utility," << endl;

	for(list<DidResult>::const_iterator row = results.begin(); row != results.end(); ++row)
	{
		csv << row->model << ","
			<< row->timeStep << ","
			<< row->state << ","
			<< row->observation << ","
			<< row->action << ","
			<< row->utility << "," << endl;
	}

	csv.close();
}

//////////////////////////////////////////////////////////////////////////
//Appends I-DID simulation results to fileName as CSV. The header row is
//written only when the file could not be opened for reading beforehand
//(normally: it did not exist yet). Every row ends with a trailing comma.
//////////////////////////////////////////////////////////////////////////
void Simulator::SaveIDIDToCsvFile(list<IDidResult>results, string fileName)
{
	//probe for an existing file; the scope closes the probe before appending
	bool alreadyThere;
	{
		ifstream probe(fileName.c_str());
		alreadyThere = probe.good();
	}

	ofstream csv;
	csv.open(fileName.c_str(), std::ofstream::out | std::ofstream::app);

	if(!alreadyThere)
	{
		csv << "OponentModel,TimeStep,State,Observation,Action,OponentAction,Utility," << endl;
	}

	for(list<IDidResult>::const_iterator row = results.begin(); row != results.end(); ++row)
	{
		csv << row->oponentModel << ","
			<< row->timeStep << ","
			<< row->state << ","
			<< row->observation << ","
			<< row->action << ","
			<< row->oponentAction << ","
			<< row->utility << "," << endl;
	}

	csv.close();
}
